from src.utils import *
from src.layers import *
import csv
# Score every drug pair (i, j) with i < j using a saved model and write the
# results to test.csv as rows of: i, j, score, "known"/"unknown". A pair is
# "known" when its entry in ./data/mat_drug_drug.txt equals 1.
#
# NOTE(review): `torch` and `np` are assumed to come from the star imports of
# src.utils / src.layers above — confirm.

# Inference device. map_location moves every tensor in the checkpoint here, so
# the model and the input tensors built below always share a device (a model
# saved on a GPU would otherwise clash with CPU inputs). Set to
# torch.device("cuda:0") to run on a GPU.
device = torch.device("cpu")
the_model = torch.load("./saved_model/tip-cat-example.pt", map_location=device)
# NOTE(review): if the_model is an nn.Module with dropout/batch-norm layers,
# calling the_model.eval() here is probably intended — confirm against pred.

dd_list = np.loadtxt("./data/mat_drug_drug.txt")
if dd_list.ndim != 2 or dd_list.shape[0] != dd_list.shape[1]:
    raise ValueError(
        f"expected a square drug-drug matrix, got shape {dd_list.shape}"
    )

# Size the pair enumeration from the data instead of a hard-coded drug count.
# triu_indices(n, k=1) yields the strict upper triangle in row-major order,
# i.e. the same (i, j) sequence as `for i in range(n): for j in range(i+1, n)`.
rows, cols = np.triu_indices(dd_list.shape[0], k=1)
list_1 = rows.tolist()
list_2 = cols.tolist()
k = len(list_1)  # number of pairs: n * (n - 1) / 2

predict_list = [list_1, list_2]
predict_tensor = torch.tensor(predict_list, device=device)  # shape (2, k)
# One zero per pair, passed as the second argument of pred (presumably an
# edge-type id — confirm).
dd_et = torch.zeros(k, dtype=torch.long, device=device)

# Pure inference: no_grad avoids building an autograd graph over all k pairs.
with torch.no_grad():
    predict_result = the_model.pred(predict_tensor, dd_et)

# Copy the scores to host memory once instead of calling .item() per pair.
# reshape(-1) accepts both a (k,) and a (k, 1) result.
scores = predict_result.detach().cpu().reshape(-1).tolist()
if len(scores) != k:
    raise ValueError(f"expected {k} scores from pred, got {len(scores)}")

# newline="" is required by the csv module; without it, blank lines appear
# between rows on Windows.
with open("test.csv", "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    for i, j, score in zip(list_1, list_2, scores):
        status = "known" if dd_list[i][j] == 1 else "unknown"
        writer.writerow([i, j, score, status])

print("ok")
